"""
Interest rate differential and carry trade variable construction.

Extensions beyond Koomen & Wicht (2020):
1. Real interest rate differentials (Carvalho, Ferrero & Nechio 2016)
2. Demographic distance measures between country pairs
3. Financial integration interactions (demographics × KAOPEN)
4. FX-hedged yield differentials (Japan carry trade)
"""

import pandas as pd
import numpy as np
from pathlib import Path

# Directory holding the raw input CSVs read by the loader functions in this module.
# NOTE(review): absolute, machine-specific path (the /mnt/c prefix looks like a
# WSL mount of a Windows drive) — confirm before running on another machine.
RAW_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/data/raw")
# Directory the assembled interest-rate panel CSV is written to.
PROCESSED_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/data/processed")


def load_fred_rates():
    """Read the FRED interest rate extract from the raw-data directory.

    Returns:
        The parsed contents of ``fred_rates.csv`` under ``RAW_DIR`` as a
        DataFrame, or ``None`` when that file does not exist.
    """
    source = RAW_DIR / "fred_rates.csv"
    return pd.read_csv(source) if source.exists() else None


def load_imf_ifs_rates():
    """Read the IMF IFS interest rate extract from the raw-data directory.

    Returns:
        The parsed contents of ``imf_ifs_rates.csv`` under ``RAW_DIR`` as a
        DataFrame, or ``None`` when that file does not exist.
    """
    source = RAW_DIR / "imf_ifs_rates.csv"
    return pd.read_csv(source) if source.exists() else None


def combine_rate_sources(fred_df=None, ifs_df=None):
    """
    Combine FRED and IMF IFS rate data into a single country-year panel.

    For the two series both sources can provide (``govt_bond_10y`` and
    ``short_rate_3m``) the FRED value is kept wherever it is non-missing and
    the IFS value only fills gaps. IFS-only series (``policy_rate``,
    ``lending_rate``) are carried through unchanged.

    Args:
        fred_df: DataFrame with columns ``iso3``, ``year``, ``govt_bond_10y``
            and ``short_rate_3m``, or ``None``.
        ifs_df: DataFrame with ``year``, a country key (``iso3``, or ``iso2``
            which is converted via ``ISO2_TO_ISO3``) and any of the raw IFS
            columns ``figb_pa``, ``fitb_pa``, ``fpolm_pa``, ``filr_pa`` (or
            their already-renamed equivalents), or ``None``.

    Returns:
        The outer-merged panel keyed on (``iso3``, ``year``), or ``None`` when
        neither source contributes a frame.

    Raises:
        KeyError: if ``fred_df`` lacks one of its four required columns, or
            the IFS frame lacks ``iso3``/``year`` after conversion.
    """
    shared_cols = ['govt_bond_10y', 'short_rate_3m']
    # Raw IFS column -> common name; insertion order fixes the output column order.
    ifs_names = {
        'figb_pa': 'govt_bond_10y',
        'fitb_pa': 'short_rate_3m',
        'fpolm_pa': 'policy_rate',
        'filr_pa': 'lending_rate',
    }

    frames = []

    if fred_df is not None:
        frames.append(fred_df[['iso3', 'year'] + shared_cols].copy())

    if ifs_df is not None:
        ifs = ifs_df.copy()
        if 'iso2' in ifs.columns and 'iso3' not in ifs.columns:
            # Imported only when a conversion is actually required, so the
            # function stays usable when the sibling module is not importable
            # (e.g. the file is run as a plain script) and data is already ISO3.
            from .macro import ISO2_TO_ISO3

            ifs['iso3'] = ifs['iso2'].map(ISO2_TO_ISO3)
            # Codes without an ISO3 mapping cannot be merged; drop them.
            ifs = ifs.dropna(subset=['iso3'])

        # Skip a rename whose target already exists: renaming anyway would
        # leave two columns with the same label and break the merge below.
        col_map = {
            raw: name for raw, name in ifs_names.items()
            if raw in ifs.columns and name not in ifs.columns
        }
        ifs = ifs.rename(columns=col_map)

        rate_cols = [c for c in ifs_names.values() if c in ifs.columns]
        if rate_cols:
            frames.append(ifs[['iso3', 'year'] + rate_cols])

    if not frames:
        return None

    # FRED (when present) is first, so its values win on overlap.
    combined = frames[0]
    for extra in frames[1:]:
        combined = combined.merge(
            extra, on=['iso3', 'year'], how='outer', suffixes=('', '_ifs')
        )
        # Fill missing FRED values with IFS values, then discard the IFS copy.
        for col in shared_cols:
            ifs_col = f'{col}_ifs'
            if ifs_col in combined.columns:
                combined[col] = combined[col].fillna(combined[ifs_col])
                combined = combined.drop(columns=[ifs_col])

    return combined


def compute_real_rate_differentials(rates_df, macro_df):
    """
    Compute real interest rate differentials against a world average.

    real_rate = nominal_rate - inflation
    differential = real_rate_i - real_rate_world

    The world rate for a year is the GDP-weighted average real rate over the
    rows of that year that have both a real rate and ``ngdp_usd``; weights are
    ``ngdp_usd`` floored at 0.1 so non-positive GDP values cannot zero out or
    flip the weighting.

    Args:
        rates_df: Panel with ``iso3``, ``year`` and any of ``govt_bond_10y`` /
            ``short_rate_3m``; ``None`` short-circuits to ``None``.
        macro_df: Panel with ``iso3``, ``year``, ``inflation``, ``ngdp_usd``.
            Rows with missing inflation are ignored.

    Returns:
        ``rates_df`` (left-joined, so every rate row is kept) with, for each
        nominal column present, the added columns ``real_*``, ``real_*_world``
        and ``real_*_diff``. Years without any usable observation get NaN for
        the world rate and the differential. ``None`` if ``rates_df`` is None.
    """
    if rates_df is None:
        return None

    df = rates_df.merge(
        macro_df[['iso3', 'year', 'inflation', 'ngdp_usd']].dropna(subset=['inflation']),
        on=['iso3', 'year'],
        how='left'
    )

    # Compute real rates
    for nom_col, real_col in [
        ('govt_bond_10y', 'real_bond_10y'),
        ('short_rate_3m', 'real_short_3m'),
    ]:
        if nom_col in df.columns:
            df[real_col] = df[nom_col] - df['inflation']

    # Compute GDP-weighted world average real rate per year
    for real_col in ['real_bond_10y', 'real_short_3m']:
        if real_col not in df.columns:
            continue

        valid = df.dropna(subset=[real_col, 'ngdp_usd'])
        weights = valid['ngdp_usd'].clip(lower=0.1)
        # sum(w * r) / sum(w) per year. Written as two grouped sums rather
        # than groupby.apply so it does not depend on the pandas-2.2-only
        # `include_groups` argument and yields an empty Series (not an error)
        # when no row is usable.
        weighted_sum = (valid[real_col] * weights).groupby(valid['year']).sum()
        weight_total = weights.groupby(valid['year']).sum()
        world_rate = weighted_sum / weight_total

        world_col = f'{real_col}_world'
        df[world_col] = df['year'].map(world_rate)
        df[f'{real_col}_diff'] = df[real_col] - df[world_col]

    result = df.drop(columns=['inflation', 'ngdp_usd'], errors='ignore')
    return result


def compute_demographic_distance(demo_df):
    """
    Compute each country's demographic distance to major capital markets.

    For every country-year, two summary measures are derived from the
    age-share columns ``n_1`` .. ``n_G``:

    - ``oadr``: old-age dependency ratio, shares of groups 14-17 divided by
      shares of groups 4-13 (65+ over 15-64 under the 5-year bins implied by
      the midpoints below).
    - ``approx_median_age``: share-weighted mean of the age-bin midpoints,
      used as a proxy for the median age.

    The distance is the signed difference between the country's measure and
    the same-year unweighted mean over the capital-market countries
    (USA, GBR, DEU, JPN, CHN, FRA, CAN, AUS) that are present in that year.

    Args:
        demo_df: Panel with ``iso3``, ``year`` and age-share columns. Rows
            with a missing value in any used column are dropped.

    Returns:
        DataFrame with columns ``iso3``, ``year``, ``oadr_dist_to_cm`` and
        ``median_age_dist_to_cm``: one row per retained input row whose year
        has at least one capital-market country. Rows are grouped by year in
        order of first appearance. The columns are present even when the
        result is empty.

    Raises:
        KeyError: if ``iso3`` or ``year`` is missing from ``demo_df``.
        ValueError: if a share column refers to an age group beyond the 17
            midpoints defined here.
    """
    from .demographics import G

    out_cols = ['iso3', 'year', 'oadr_dist_to_cm', 'median_age_dist_to_cm']

    # Age groups for which a share column is actually present.
    groups = [g for g in range(1, G + 1) if f'n_{g}' in demo_df.columns]
    share_cols = [f'n_{g}' for g in groups]
    # Explicit copy: derived columns are added below.
    df = demo_df[['iso3', 'year'] + share_cols].dropna().copy()

    # Old-age dependency ratio: (65+) / (15-64)
    old_cols = [f'n_{g}' for g in groups if 14 <= g <= 17]
    wa_cols = [f'n_{g}' for g in groups if 4 <= g <= 13]
    df['oadr'] = df[old_cols].sum(axis=1) / df[wa_cols].sum(axis=1)

    # Midpoint of each age bin, indexed by group number minus one.
    midpoints = np.array([2.5, 7.5, 12.5, 17.5, 22.5, 27.5, 32.5,
                          37.5, 42.5, 47.5, 52.5, 57.5, 62.5, 67.5,
                          72.5, 77.5, 85.0])
    if groups and groups[-1] > len(midpoints):
        raise ValueError(
            f"age group n_{groups[-1]} has no midpoint; "
            f"only {len(midpoints)} age bins are defined"
        )
    # Pick the midpoint belonging to each present column. Slicing the first
    # len(share_cols) midpoints would pair columns with the wrong ages
    # whenever a non-trailing share column is absent.
    group_midpoints = midpoints[np.array(groups, dtype=int) - 1]
    # Weighted mean as proxy for median
    df['approx_median_age'] = df[share_cols].to_numpy() @ group_midpoints

    capital_markets = ['USA', 'GBR', 'DEU', 'JPN', 'CHN', 'FRA', 'CAN', 'AUS']

    # Per-year benchmark: mean over the capital-market rows of that year.
    cm_means = (
        df.loc[df['iso3'].isin(capital_markets)]
        .groupby('year')[['oadr', 'approx_median_age']]
        .mean()
    )

    # Years without any capital-market country have no benchmark; skip them.
    covered = df[df['year'].isin(cm_means.index)]
    if covered.empty:
        return pd.DataFrame(columns=out_cols)

    # Group rows by year in first-appearance order (stable within a year).
    first_seen = {year: pos for pos, year in enumerate(covered['year'].unique())}
    order = np.argsort(covered['year'].map(first_seen).to_numpy(), kind='stable')
    covered = covered.iloc[order]

    benchmark_oadr = covered['year'].map(cm_means['oadr'])
    benchmark_age = covered['year'].map(cm_means['approx_median_age'])

    return pd.DataFrame({
        'iso3': covered['iso3'].to_numpy(),
        'year': covered['year'].to_numpy(),
        'oadr_dist_to_cm': (covered['oadr'] - benchmark_oadr).to_numpy(),
        'median_age_dist_to_cm': (covered['approx_median_age'] - benchmark_age).to_numpy(),
    }, columns=out_cols)


def compute_carry_trade_variables(rates_df):
    """
    Build FX-hedged and unhedged carry measures relative to Japan and the US.

    Under covered interest parity the cost of hedging currency i against
    currency j is roughly the short-rate gap, ``short_rate_i - short_rate_j``.
    The FX-hedged pick-up from holding j's bond instead of i's is therefore

        (yield_j - yield_i) - (short_rate_i - short_rate_j)
            = (yield_j - short_rate_j) - (yield_i - short_rate_i)
            = term_spread_j - term_spread_i

    i.e. a difference of term spreads. Each row's term spread is compared
    with the same-year term spread of Japan (the primary carry trade
    originator) and of the US; the unhedged carry is the plain short-rate
    difference against the same two references.

    Args:
        rates_df: Panel with ``iso3``, ``year``, ``govt_bond_10y`` and
            ``short_rate_3m``; ``None`` is passed through as ``None``.

    Returns:
        A copy of the panel with ``term_spread``, the per-year reference
        columns (``term_spread_jpn``, ``short_rate_jpn``, ``term_spread_usa``,
        ``short_rate_usa``) and the four measures ``fx_hedged_vs_jpn``,
        ``fx_hedged_vs_usa``, ``carry_vs_jpn``, ``carry_vs_usa``. Returns
        ``None`` when no term spread is available.
    """
    if rates_df is None:
        return None

    panel = rates_df.copy()

    if 'govt_bond_10y' in panel.columns and 'short_rate_3m' in panel.columns:
        panel['term_spread'] = panel['govt_bond_10y'] - panel['short_rate_3m']
    elif 'term_spread' not in panel.columns:
        return None

    # (country code, renamed term-spread column, renamed short-rate column)
    reference_specs = (
        ('JPN', 'term_spread_jpn', 'short_rate_jpn'),
        ('USA', 'term_spread_usa', 'short_rate_usa'),
    )

    # Extract every reference series before merging any of them, so each one
    # is taken from the un-merged panel.
    references = []
    for code, spread_name, short_name in reference_specs:
        is_reference = panel['iso3'] == code
        by_year = panel[is_reference][['year', 'term_spread', 'short_rate_3m']]
        references.append(by_year.rename(
            columns={'term_spread': spread_name, 'short_rate_3m': short_name}
        ))

    for reference in references:
        panel = panel.merge(reference, on='year', how='left')

    # (output column, own column, reference column); hedged measures first,
    # then unhedged carry.
    measures = (
        ('fx_hedged_vs_jpn', 'term_spread', 'term_spread_jpn'),
        ('fx_hedged_vs_usa', 'term_spread', 'term_spread_usa'),
        ('carry_vs_jpn', 'short_rate_3m', 'short_rate_jpn'),
        ('carry_vs_usa', 'short_rate_3m', 'short_rate_usa'),
    )
    for target, own, benchmark in measures:
        panel[target] = panel[own] - panel[benchmark]

    return panel


def construct_interaction_terms(panel_df):
    """
    Construct interaction terms between demographic variables and financial openness.

    Following Higgins (1998) and Carvalho et al. (2016):
    - Z_p × KAOPEN: demographic effects conditional on financial openness

    Every column whose name starts with ``Z_`` is multiplied by ``kaopen``
    and stored as ``<name>_x_kaopen``. Columns that are themselves earlier
    interaction terms (names ending in ``_x_kaopen``) are not interacted
    again, so applying the function to its own output is a no-op.

    Args:
        panel_df: Panel that may contain ``kaopen`` and ``Z_*`` columns.

    Returns:
        A copy of ``panel_df`` with the interaction columns appended. When
        ``kaopen`` or the ``Z_*`` columns are absent, the copy is returned
        unchanged.
    """
    suffix = '_x_kaopen'
    df = panel_df.copy()

    if 'kaopen' not in df.columns:
        return df

    # Existing interaction columns also start with 'Z_'; leaving them in
    # would create 'Z_p_x_kaopen_x_kaopen' (Z_p × kaopen²) on a second pass.
    # Non-string labels are skipped because they cannot be Z_ columns.
    z_cols = [
        c for c in df.columns
        if isinstance(c, str) and c.startswith('Z_') and not c.endswith(suffix)
    ]
    for z in z_cols:
        df[f'{z}{suffix}'] = df[z] * df['kaopen']

    return df


def assemble_rate_panel():
    """
    Assemble the full interest rate and carry trade variable panel.

    Loads the FRED and IMF IFS extracts, combines them, adds the carry trade
    variables when they can be computed, and writes the result to
    ``interest_rate_panel.csv`` under ``PROCESSED_DIR`` (created if missing).
    Progress is reported on stdout.

    Returns:
        The assembled panel, or ``None`` (and nothing is written) when neither
        rate source yields data.
    """
    print("Assembling interest rate panel...")

    fred = load_fred_rates()
    ifs = load_imf_ifs_rates()

    rates = combine_rate_sources(fred, ifs)
    if rates is None:
        return None

    print(f"  Combined rates: {rates.shape}, {rates['iso3'].nunique()} countries")

    # Compute carry trade variables; keep the plain rates if unavailable.
    carry = compute_carry_trade_variables(rates)
    if carry is not None:
        rates = carry
        print("  Added carry trade variables")

    # to_csv does not create parent directories; make sure the target exists.
    PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
    rates.to_csv(PROCESSED_DIR / "interest_rate_panel.csv", index=False)

    return rates


if __name__ == "__main__":
    # Script entry point: build the panel and, if one was produced, print a
    # short summary of its size and country coverage.
    if (rates := assemble_rate_panel()) is not None:
        print(f"\nRate panel shape: {rates.shape}")
        print(f"Countries: {rates['iso3'].nunique()}")
